from __future__ import absolute_import, division

import os
import glob
import random
import numpy as np
import xml.etree.ElementTree as ET
import torch
import torch.utils.data as data
from PIL import Image
from utils import load_stats, crop_pil
# For test
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

# Generates (template, search) image pairs for training/validating the SiamFC tracker
class Pair(data.Dataset):
    """Image pairs (template z, search x) for training/validating a SiamFC tracker.

    Each item selects a video of ILSVRC2015 VID, an object track that is
    annotated in at least two frames of that video, and two of those frames.
    Both frames are cropped around the object, and the item is returned
    together with the score-map labels and loss weights as
    (crop_z, crop_x, labels, weights).
    """

    # Maximum number of random videos drawn by __getitem__ (rand_choice=True)
    # before giving up because none of them has a usable object track.
    _MAX_RESAMPLE = 100

    def __init__(self, root_dir, subset='train', transform=None, config=None, pairs_per_video=25,
                 stats_path='ILSVRC2015.stats.mat', frame_range=100, rand_choice=True):
        """
        root_dir: path to the original ILSVRC2015 VID dataset
        subset: 'train' or 'val'
        transform: callable applied to every cropped patch (e.g. RandomHorizontalFlip + ToTensor);
                   its output is multiplied by 255 and clamped with torch, so it must return a tensor
        config: object providing exemplarSize, instanceSize, scoreSize, context,
                rPos, rNeg, totalStride and ignoreLabel
        pairs_per_video: number of image pairs drawn per video (sets the dataset length)
        stats_path: file holding the RGB variance of x (search) and z (template),
                    used for color augmentation on the train subset
        frame_range: maximum distance between z and x, counted in positions of the
                     list of frames that contain the selected object
        rand_choice: if True, ignore the requested index and pick a video at random

        Raises ValueError if subset is not 'train'/'val' or config is None.
        """
        super(Pair, self).__init__()
        # explicit checks instead of assert so they survive `python -O`
        if subset not in ('train', 'val'):
            raise ValueError("subset must be 'train' or 'val', got %r" % (subset,))
        if config is None:
            raise ValueError('config is required (exemplarSize, instanceSize, scoreSize, context, '
                             'rPos, rNeg, totalStride, ignoreLabel)')
        self.root_dir = root_dir                    # path to the original ILSVRC2015 VID
        self.subset = subset
        self.seq_names = []                         # video names of the selected subset
        self.anno_dirs = []                         # annotation directory of each video

        data_dir = os.path.join(self.root_dir, 'Data/VID', subset)
        anno_root = os.path.join(self.root_dir, 'Annotations/VID', subset)
        if subset == 'train':
            # train videos are nested one level deeper: Data/VID/train/ILSVRC2015*/ILSVRC2015*
            self.seq_dirs = sorted(glob.glob(os.path.join(data_dir, 'ILSVRC2015*/ILSVRC2015*')))
        else:
            self.seq_dirs = sorted(glob.glob(os.path.join(data_dir, 'ILSVRC2015*')))

        for seq_dir in self.seq_dirs:
            # normpath + basename/dirname instead of split('/') so OS separators work too
            seq_dir = os.path.normpath(seq_dir)
            name = os.path.basename(seq_dir)
            self.seq_names.append(name)
            if subset == 'train':
                # mirror the two-level <group>/<video> layout under Annotations
                group = os.path.basename(os.path.dirname(seq_dir))
                self.anno_dirs.append(os.path.join(anno_root, group, name))
            else:
                self.anno_dirs.append(os.path.join(anno_root, name))

        self.transform = transform
        self.stats = load_stats(stats_path)
        self.pairs_per_video = pairs_per_video
        self.frame_range = frame_range
        self.rand_choice = rand_choice

        n = len(self.seq_names)                     # if train, n=3862; if val, n=555
        # every video index repeated pairs_per_video times -> len(dataset) = number of pairs
        self.indices = np.tile(np.arange(0, n, dtype=int), self.pairs_per_video)

        # load parameters from config
        self.exemplarSize = config.exemplarSize     # template size (127)
        self.instanceSize = config.instanceSize     # search size (255)
        self.scoreSize = config.scoreSize           # score map size (17)
        self.context = config.context               # context padding factor (0.5)
        self.rPos = config.rPos                     # 16
        self.rNeg = config.rNeg                     # 0
        self.totalStride = config.totalStride       # 8
        self.ignoreLabel = config.ignoreLabel       # -100

    def __getitem__(self, index):
        """Return (crop_z, crop_x, labels, weights) for one sampled image pair.

        crop_z / crop_x: transformed template/search patches multiplied by 255
        (clamped to [0, 255] after color augmentation on the train subset).
        labels / weights: score-map labels and loss weights as float tensors,
        each with a leading singleton axis.

        Raises ValueError when no object track with two annotated frames is found.
        """
        img_files, anno = self._sample_track(index)

        # choose z (template) and x (search) among the frames containing the object
        rand_z, rand_x = self.sample_pair(len(img_files))
        img_z = self._load_rgb(img_files[rand_z])
        img_x = self._load_rgb(img_files[rand_x])

        # crop around the object: z with exemplarSize, x with instanceSize
        crop_z = self.crop(img_z, anno[rand_z, :], self.exemplarSize)
        crop_x = self.crop(img_x, anno[rand_x, :], self.instanceSize)

        labels, weights = self.create_labels()
        labels = torch.from_numpy(labels).float()
        weights = torch.from_numpy(weights).float()

        # data augmentation / tensor conversion
        # (train: RandomHorizontalFlip + ToTensor, val: ToTensor in the smoke test)
        crop_z = self.transform(crop_z) * 255.0
        crop_x = self.transform(crop_x) * 255.0
        # color augmentation - only for the train subset
        if self.subset == 'train':
            crop_z = self._color_jitter(crop_z, self.stats.rgb_variance_z)
            crop_x = self._color_jitter(crop_x, self.stats.rgb_variance_x)

        return crop_z, crop_x, labels, weights

    def __len__(self):
        return len(self.indices)

    def _sample_track(self, index):
        """Pick an object track of video `index`; return (img_files, anno).

        img_files: image path of every frame in which the track is annotated.
        anno: int array of shape (len(img_files), 4) holding [xmin, ymin, w, h].
        With rand_choice the video is drawn at random and redrawn (up to
        _MAX_RESAMPLE times) when it has no track annotated in two frames.
        """
        attempts = self._MAX_RESAMPLE if self.rand_choice else 1
        for _ in range(attempts):
            if self.rand_choice:
                index = np.random.choice(self.indices)
            anno_files = sorted(glob.glob(os.path.join(self.anno_dirs[index], '*.xml')))
            # objects[i] holds every <object> element of annotation file anno_files[i]
            objects = [ET.ElementTree(file=f).findall('object') for f in anno_files]
            track_ids, counts = np.unique(
                [obj.find('trackid').text for group in objects for obj in group],
                return_counts=True)
            # a pair needs two distinct frames, so only tracks seen at least twice qualify
            candidates = track_ids[counts >= 2]
            if len(candidates) > 0:
                break
        else:
            # ValueError rather than IndexError: an IndexError from __getitem__ would
            # silently end a plain `for item in dataset` loop
            raise ValueError('no object track annotated in at least two frames in %s'
                             % self.anno_dirs[index])
        track_id = random.choice(candidates)

        seq_dir = self.seq_dirs[index]
        img_files = []
        anno = []
        for anno_file, group in zip(anno_files, objects):
            for obj in group:
                if obj.find('trackid').text != track_id:
                    continue
                # the frame image shares its base name with the annotation file
                stem = os.path.splitext(os.path.basename(anno_file))[0]
                img_files.append(os.path.join(seq_dir, stem + '.JPEG'))
                anno.append([int(obj.find('bndbox/' + key).text)
                             for key in ('xmin', 'ymin', 'xmax', 'ymax')])
        anno = np.array(anno)                            # [xmin, ymin, xmax, ymax]
        anno[:, 2:] = anno[:, 2:] - anno[:, :2] + 1      # -> [xmin, ymin, w, h]
        return img_files, anno

    @staticmethod
    def _load_rgb(path):
        """Read an image file as an RGB PIL image, closing the file handle."""
        with Image.open(path) as img:
            # convert() returns a loaded copy that stays valid after the file closes;
            # grayscale (and any other non-RGB) frames become 3-channel RGB
            return img.convert('RGB')

    @staticmethod
    def _color_jitter(crop, rgb_variance):
        """Add a random RGB offset (rgb_variance @ N(0, I)) per channel, clamp to [0, 255]."""
        offset = np.reshape(np.dot(rgb_variance, np.random.randn(3, 1)), (3, 1, 1))
        return torch.clamp(crop + torch.from_numpy(offset).float(), 0.0, 255.0)

    def sample_pair(self, n):
        """Return indices (z, x) in [0, n) with |z - x| <= frame_range.

        If frame_range is 0 the same index is returned twice; otherwise x != z,
        which requires n >= 2.
        """
        rand_z = np.random.randint(n)               # template index
        if self.frame_range == 0:
            return rand_z, rand_z
        # candidates within frame_range on both sides (inclusive), clipped to [0, n)
        low = max(0, rand_z - self.frame_range)
        high = min(n, rand_z + self.frame_range + 1)
        possible_x = np.arange(low, high)
        possible_x = possible_x[possible_x != rand_z]   # z and x must differ
        rand_x = np.random.choice(possible_x)
        return rand_z, rand_x

    def crop(self, image, bndbox, out_size):
        """Crop the patch around bndbox ([xmin, ymin, w, h]) via crop_pil with out_size.

        The patch side in the source image is sqrt((w + p) * (h + p)) with
        p = context * (w + h), scaled by out_size / exemplarSize so a search
        patch (instanceSize) covers proportionally more context than a template.
        Resizing to out_size is presumably done by crop_pil — TODO confirm.
        """
        center = bndbox[:2] + bndbox[2:] / 2
        size = bndbox[2:]

        context = self.context * size.sum()
        patch_sz = out_size / self.exemplarSize * np.sqrt((size + context).prod())

        return crop_pil(image, center, patch_sz, out_size=out_size)

    def create_labels(self):
        """Return (labels, weights), each shaped (1, scoreSize, scoreSize).

        Positives and negatives each get half of the total weight, split evenly
        among their cells; ignored cells get weight 0 (so the weights sum to 1
        when both classes are present). Mirrors the Matlab SiamFC weighting.
        """
        labels = self.create_logisticloss_labels()
        pos = labels == 1
        neg = labels == 0
        weights = np.zeros_like(labels)
        weights[pos] = 0.5 / np.sum(pos)
        weights[neg] = 0.5 / np.sum(neg)
        return labels[np.newaxis, :], weights[np.newaxis, :]

    def create_logisticloss_labels(self):
        """Label each score-map cell by its distance (in cells) to the map center.

        dist <= rPos / totalStride                  -> 1 (positive)
        rPos / stride < dist <= rNeg / stride       -> ignoreLabel
        otherwise                                   -> 0 (negative)
        Returns a float64 array of shape (scoreSize, scoreSize).
        """
        label_sz = self.scoreSize
        r_pos = self.rPos / self.totalStride
        r_neg = self.rNeg / self.totalStride
        offsets = np.arange(label_sz) - label_sz // 2
        dist = np.sqrt(offsets[:, np.newaxis] ** 2 + offsets[np.newaxis, :] ** 2)

        labels = np.zeros((label_sz, label_sz))
        labels[dist <= r_neg] = self.ignoreLabel
        labels[dist <= r_pos] = 1                   # positives take precedence
        return labels


# Smoke-test code: default SiamFC config and a manual run over the VID dataset
class train_config(object):
    """Default SiamFC hyper-parameters, passed to Pair as its `config`."""

    # patch and score-map geometry
    exemplarSize = 127      # side of the template (z) patch
    instanceSize = 255      # side of the search (x) patch
    scoreSize = 17          # side of the score map
    totalStride = 8         # stride between score-map cells, in input pixels

    # crop context: padding added around the box is context * (w + h)
    context = 0.5

    # label radii in input pixels (divided by totalStride to get score-map cells)
    rPos = 16
    rNeg = 0

    # label value for cells excluded from the loss
    ignoreLabel = -100

if __name__ == '__main__':
    # Smoke test: build the train and val pair datasets and draw one sample from each.
    # The dataset root may be given as the first command-line argument.
    import sys

    root_dir = sys.argv[1] if len(sys.argv) > 1 else '/home/pylab/LHWorkspace/ILSVRC2015'
    config = train_config()

    transforms_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
    pair_train = Pair(root_dir=root_dir, subset='train', transform=transforms_train,
                      config=config)
    print(len(pair_train))
    crop_z, crop_x, labels, weights = pair_train[20]
    print(crop_z.shape, crop_x.shape, labels.shape, weights.shape)

    transforms_val = transforms.ToTensor()
    pair_val = Pair(root_dir=root_dir, subset='val', transform=transforms_val,
                    config=config)
    print(len(pair_val))
    crop_z, crop_x, labels, weights = pair_val[20]
    print(crop_z.shape, crop_x.shape, labels.shape, weights.shape)